diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fe7966ea9..1cadc0339 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,6 +76,9 @@ jobs: - name: Run Milvus run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:1.1.0-cpu-d050721-5e559c + - name: Run Weaviate + run: docker run -d -p 8080:8080 --name haystack_test_weaviate --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' semitechnologies/weaviate:1.4.0 + - name: Run GraphDB run: docker run -d -p 7200:7200 --name haystack_test_graphdb deepset/graphdb-free:9.4.1-adoptopenjdk11 diff --git a/docs/_src/usage/usage/document_store.md b/docs/_src/usage/usage/document_store.md index 43baecdb0..bd4e5b3f4 100644 --- a/docs/_src/usage/usage/document_store.md +++ b/docs/_src/usage/usage/document_store.md @@ -116,6 +116,27 @@ from haystack.document_store import SQLDocumentStore document_store = SQLDocumentStore() ``` + + + +
+ + +
+ +The `WeaviateDocumentStore` requires a running Weaviate Server. +You can start a basic instance like this (see Weaviate docs for details): +``` + docker run -d -p 8080:8080 --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' semitechnologies/weaviate:1.4.0 +``` + +Afterwards, you can use it in Haystack: +```python +from haystack.document_store import WeaviateDocumentStore + +document_store = WeaviateDocumentStore() +``` +
@@ -264,6 +285,24 @@ The Document Stores have different characteristics. You should choose one depend + +
+ + +
+ +**Pros:** +- Simple vector search +- Stores everything in one place: documents, meta data and vectors - so less network overhead when scaling this up +- Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset + +**Cons:** +- Fewer options for ANN algorithms than FAISS or Milvus +- No BM25 / Tf-idf retrieval + +
+
+
@@ -276,4 +315,4 @@ The Document Stores have different characteristics. You should choose one depend **Vector Specialist:** Use the `MilvusDocumentStore`, if you want to focus on dense retrieval and possibly deal with larger datasets -
\ No newline at end of file + diff --git a/haystack/document_store/weaviate.py b/haystack/document_store/weaviate.py new file mode 100644 index 000000000..e654a520d --- /dev/null +++ b/haystack/document_store/weaviate.py @@ -0,0 +1,716 @@ +import logging +from typing import Any, Dict, Generator, List, Optional, Union +import numpy as np +from tqdm import tqdm + +from haystack import Document +from haystack.document_store.base import BaseDocumentStore +from haystack.utils import get_batches_from_generator + +from weaviate import client, auth, AuthClientPassword +from weaviate import ObjectsBatchRequest + +logger = logging.getLogger(__name__) + +class WeaviateDocumentStore(BaseDocumentStore): + """ + + Weaviate is a cloud-native, modular, real-time vector search engine built to scale your machine learning models. + (See https://www.semi.technology/developers/weaviate/current/index.html#what-is-weaviate) + + Some of the key differences in contrast to FAISS & Milvus: + 1. Stores everything in one place: documents, meta data and vectors - so less network overhead when scaling this up + 2. Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset + 3. Has less variety of ANN algorithms, as of now only HNSW. + + Weaviate python client is used to connect to the server, more details are here + https://weaviate-python-client.readthedocs.io/en/docs/weaviate.html + + Usage: + 1. Start a Weaviate server (see https://www.semi.technology/developers/weaviate/current/getting-started/installation.html) + 2. 
Init a WeaviateDocumentStore in Haystack + """ + + def __init__( + self, + host: Union[str, List[str]] = "http://localhost", + port: Union[int, List[int]] = 8080, + timeout_config: tuple = (5, 15), + username: str = None, + password: str = None, + index: str = "Document", + embedding_dim: int = 768, + text_field: str = "text", + name_field: str = "name", + faq_question_field = "question", + similarity: str = "dot_product", + index_type: str = "hnsw", + custom_schema: Optional[dict] = None, + return_embedding: bool = False, + embedding_field: str = "embedding", + progress_bar: bool = True, + duplicate_documents: str = 'overwrite', + **kwargs, + ): + """ + :param host: Weaviate server connection URL for storing and processing documents and vectors. + For more details, refer "https://www.semi.technology/developers/weaviate/current/getting-started/installation.html" + :param port: port of Weaviate instance + :param timeout_config: Weaviate timeout config as a tuple of (connect timeout, read timeout) in seconds. + :param username: username (standard authentication via http_auth) + :param password: password (standard authentication via http_auth) + :param index: Index name for document text, embedding and metadata (in Weaviate terminology, this is a "Class" in Weaviate schema). + :param embedding_dim: The embedding vector size. Default: 768. + :param text_field: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text"). + If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned. + :param name_field: Name of field that contains the title of the doc + :param faq_question_field: Name of field containing the question in case of FAQ-Style QA + :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default. + :param index_type: Index type of any vector object defined in weaviate schema. The vector index type is pluggable. 
+ Currently, only HNSW is supported. + See: https://www.semi.technology/developers/weaviate/current/more-resources/performance.html + :param custom_schema: Allows to create custom schema in Weaviate, for more details + See https://www.semi.technology/developers/weaviate/current/data-schema/schema-configuration.html + :param module_name : Vectorization module to convert data into vectors. Default is "text2vec-transformers" + For more details, see https://www.semi.technology/developers/weaviate/current/modules/ + :param return_embedding: To return document embedding. + :param embedding_field: Name of field containing an embedding vector. + :param progress_bar: Whether to show a tqdm progress bar or not. + Can be helpful to disable in production deployments to keep the logs clean. + :param duplicate_documents: Handle duplicate documents based on parameter options. + Parameter options : ( 'skip','overwrite','fail') + skip: Ignore the duplicate documents + overwrite: Update any existing documents with the same ID when adding documents. + fail: an error is raised if the document ID of the document being added already exists. 
+ """ + + # save init parameters to enable export of component config as YAML + self.set_config( + host=host, port=port, timeout_config=timeout_config, username=username, password=password, + index=index, embedding_dim=embedding_dim, text_field=text_field, name_field=name_field, + faq_question_field=faq_question_field, similarity=similarity, index_type=index_type, + custom_schema=custom_schema,return_embedding=return_embedding, embedding_field=embedding_field, + progress_bar=progress_bar, duplicate_documents=duplicate_documents + ) + + # Connect to Weaviate server using python binding + weaviate_url =f"{host}:{port}" + if username and password: + secret = AuthClientPassword(username, password) + self.weaviate_client = client.Client(url=weaviate_url, + auth_client_secret=secret, + timeout_config=timeout_config) + else: + self.weaviate_client = client.Client(url=weaviate_url, + timeout_config=timeout_config) + + # Test Weaviate connection + try: + status = self.weaviate_client.is_ready() + if not status: + raise ConnectionError( + f"Initial connection to Weaviate failed. Make sure you run Weaviate instance " + f"at `{weaviate_url}` and that it has finished the initial ramp up (can take > 30s)." + ) + except Exception: + raise ConnectionError( + f"Initial connection to Weaviate failed. Make sure you run Weaviate instance " + f"at `{weaviate_url}` and that it has finished the initial ramp up (can take > 30s)." 
+ ) + self.index = index + self.embedding_dim = embedding_dim + self.text_field = text_field + self.name_field = name_field + self.faq_question_field = faq_question_field + self.similarity = similarity + self.index_type = index_type + self.custom_schema = custom_schema + self.return_embedding = return_embedding + self.embedding_field = embedding_field + self.progress_bar = progress_bar + self.duplicate_documents = duplicate_documents + + self._create_schema_and_index_if_not_exist(self.index) + + def _create_schema_and_index_if_not_exist( + self, + index: Optional[str] = None, + ): + """Create a new index (schema/class in Weaviate) for storing documents in case if an index (schema) with the name doesn't exist already.""" + index = index or self.index + + if self.custom_schema: + schema = self.custom_schema + else: + schema = { + "classes": [ + { + "class": index, + "description": "Haystack index, it's a class in Weaviate", + "invertedIndexConfig": { + "cleanupIntervalSeconds": 60 + }, + "vectorizer": "none", + "properties": [ + { + "dataType": [ + "string" + ], + "description": "Name Field", + "name": self.name_field + }, + { + "dataType": [ + "string" + ], + "description": "Question Field", + "name": self.faq_question_field + }, + { + "dataType": [ + "text" + ], + "description": "Document Text", + "name": self.text_field + }, + ], + } + ] + } + if not self.weaviate_client.schema.contains(schema): + self.weaviate_client.schema.create(schema) + + def _convert_weaviate_result_to_document( + self, + result: dict, + return_embedding: bool + ) -> Document: + """ + Convert weaviate result dict into haystack document object. This is more involved because + weaviate search result dict varies between get and query interfaces. + Weaviate get methods return the data items in properties key, whereas the query doesn't. 
+ """ + score = None + probability = None + text = "" + question = None + + id = result.get("id") + embedding = result.get("vector") + + # If properties key is present, get all the document fields from it. + # otherwise, a direct lookup in result root dict + props = result.get("properties") + if not props: + props = result + + if props.get(self.text_field) is not None: + text = str(props.get(self.text_field)) + + if props.get(self.faq_question_field) is not None: + question = props.get(self.faq_question_field) + + # Weaviate creates "_additional" key for semantic search + if "_additional" in props: + if "certainty" in props["_additional"]: + score = props["_additional"]['certainty'] + probability = score + if "id" in props["_additional"]: + id = props["_additional"]['id'] + if "vector" in props["_additional"]: + embedding = props["_additional"]['vector'] + props.pop("_additional", None) + + # We put all additional data of the doc into meta_data and return it in the API + meta_data = {k:v for k,v in props.items() if k not in (self.text_field, self.faq_question_field, self.embedding_field)} + + if return_embedding and embedding: + embedding = np.asarray(embedding, dtype=np.float32) + + document = Document( + id=id, + text=text, + meta=meta_data, + score=score, + probability=probability, + question=question, + embedding=embedding, + ) + return document + + def _create_document_field_map(self) -> Dict: + return { + self.text_field: "text", + self.embedding_field: "embedding", + self.faq_question_field if self.faq_question_field else "question": "question" + } + + def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]: + """Fetch a document by specifying its text id string""" + # Sample result dict from a get method + '''{'class': 'Document', + 'creationTimeUnix': 1621075584724, + 'id': '1bad51b7-bd77-485d-8871-21c50fab248f', + 'properties': {'meta': "{'key1':'value1'}", + 'name': 'name_5', + 'text': 'text_5'}, + 'vector': []}''' + 
index = index or self.index + document = None + result = self.weaviate_client.data_object.get_by_id(id, with_vector=True) + if result: + document = self._convert_weaviate_result_to_document(result, return_embedding=True) + return document + + def get_documents_by_id(self, ids: List[str], index: Optional[str] = None, + batch_size: int = 10_000) -> List[Document]: + """Fetch documents by specifying a list of text id strings""" + index = index or self.index + documents = [] + #TODO: better implementation with multiple where filters instead of chatty call below? + for id in ids: + result = self.weaviate_client.data_object.get_by_id(id, with_vector=True) + if result: + document = self._convert_weaviate_result_to_document(result, return_embedding=True) + documents.append(document) + return documents + + def _get_current_properties(self, index: Optional[str] = None) -> List[str]: + """Get all the existing properties in the schema""" + index = index or self.index + cur_properties = [] + for class_item in self.weaviate_client.schema.get()['classes']: + if class_item['class'] == index: + cur_properties = [item['name'] for item in class_item['properties']] + + return cur_properties + + def _build_filter_clause(self, filters:Dict[str, List[str]]) -> dict: + """Transform Haystack filter conditions to Weaviate where filter clauses""" + weaviate_filters = [] + weaviate_filter = {} + for key, values in filters.items(): + for value in values: + weaviate_filter = { + "path": [key], + "operator": "Equal", + "valueString": value + } + weaviate_filters.append(weaviate_filter) + if len(weaviate_filters) > 1: + filter_dict = { + "operator": "Or", + "operands": weaviate_filters + } + return filter_dict + else: + return weaviate_filter + + def _update_schema(self, new_prop:str, index: Optional[str] = None): + """Updates the schema with a new property""" + index = index or self.index + property_dict = { + "dataType": [ + "string" + ], + "description": f"dynamic property {new_prop}", + 
"name": new_prop + } + self.weaviate_client.schema.property.create(index, property_dict) + + def _check_document(self, cur_props: List[str], doc: dict) -> List[str]: + """Find the properties in the document that don't exist in the existing schema""" + return [item for item in doc.keys() if item not in cur_props] + + def write_documents( + self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, + batch_size: int = 10_000, duplicate_documents: Optional[str] = None): + """ + Add new documents to the DocumentStore. + + :param documents: List of `Dicts` or List of `Documents`. Passing an Embedding/Vector is mandatory in case weaviate is not + configured with a module. If a module is configured, the embedding is automatically generated by Weaviate. + :param index: index name for storing the docs and metadata + :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + :param duplicate_documents: Handle duplicates document based on parameter options. + Parameter options : ( 'skip','overwrite','fail') + skip: Ignore the duplicates documents + overwrite: Update any existing documents with the same ID when adding documents. + fail: an error is raised if the document ID of the document being added already + exists. 
+ :raises DuplicateDocumentError: Exception trigger on duplicate document + :return: None + """ + index = index or self.index + self._create_schema_and_index_if_not_exist(index) + field_map = self._create_document_field_map() + + duplicate_documents = duplicate_documents or self.duplicate_documents + assert duplicate_documents in self.duplicate_documents_options, \ + f"duplicate_documents parameter must be {', '.join(self.duplicate_documents_options)}" + + if len(documents) == 0: + logger.warning("Calling DocumentStore.write_documents() with empty list") + return + + # Auto schema feature https://github.com/semi-technologies/weaviate/issues/1539 + # Get and cache current properties in the schema + current_properties = self._get_current_properties(index) + + document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents] + document_objects = self._handle_duplicate_documents(document_objects, duplicate_documents) + + batched_documents = get_batches_from_generator(document_objects, batch_size) + with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar: + for document_batch in batched_documents: + docs_batch = ObjectsBatchRequest() + for idx, doc in enumerate(document_batch): + _doc = { + **doc.to_dict(field_map=self._create_document_field_map()) + } + _ = _doc.pop("score", None) + _ = _doc.pop("probability", None) + + # In order to have a flat structure in elastic + similar behaviour to the other DocumentStores, + # we "unnest" all value within "meta" + if "meta" in _doc.keys(): + for k, v in _doc["meta"].items(): + _doc[k] = v + _doc.pop("meta") + + doc_id = str(_doc.pop("id")) + vector = _doc.pop(self.embedding_field) + if _doc.get(self.faq_question_field) is None: + _doc.pop(self.faq_question_field) + + # Check if additional properties are in the document, if so, + # append the schema with all the additional properties + missing_props = self._check_document(current_properties, _doc) + if 
missing_props: + for property in missing_props: + self._update_schema(property, index) + current_properties.append(property) + + docs_batch.add(_doc, class_name=index, uuid=doc_id, vector=vector) + + # Ingest a batch of documents + results = self.weaviate_client.batch.create(docs_batch) + # Weaviate returns errors for every failed document in the batch + if results is not None: + for result in results: + if 'result' in result and 'errors' in result['result'] \ + and 'error' in result['result']['errors']: + for message in result['result']['errors']['error']: + logger.error(f"{message['message']}") + progress_bar.update(batch_size) + progress_bar.close() + + def update_document_meta(self, id: str, meta: Dict[str, str]): + """ + Update the metadata dictionary of a document by specifying its string id + """ + self.weaviate_client.data_object.update(meta, class_name=self.index, uuid=id) + + def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int: + """ + Return the number of documents in the document store. + """ + index = index or self.index + doc_count = 0 + if filters: + filter_dict = self._build_filter_clause(filters=filters) + result = self.weaviate_client.query.aggregate(index) \ + .with_fields("meta { count }") \ + .with_where(filter_dict)\ + .do() + else: + result = self.weaviate_client.query.aggregate(index)\ + .with_fields("meta { count }")\ + .do() + + if "data" in result: + if "Aggregate" in result.get('data'): + doc_count = result.get('data').get('Aggregate').get(index)[0]['meta']['count'] + + return doc_count + + def get_all_documents( + self, + index: Optional[str] = None, + filters: Optional[Dict[str, List[str]]] = None, + return_embedding: Optional[bool] = None, + batch_size: int = 10_000, + ) -> List[Document]: + """ + Get documents from the document store. + + :param index: Name of the index to get the documents from. If None, the + DocumentStore's default index (self.index) will be used. 
+ :param filters: Optional filters to narrow down the documents to return. + Example: {"name": ["some", "more"], "category": ["only_one"]} + :param return_embedding: Whether to return the document embeddings. + :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + """ + index = index or self.index + result = self.get_all_documents_generator( + index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size + ) + documents = list(result) + return documents + + def _get_all_documents_in_index( + self, + index: Optional[str], + filters: Optional[Dict[str, List[str]]] = None, + batch_size: int = 10_000, + only_documents_without_embedding: bool = False, + ) -> Generator[dict, None, None]: + """ + Return all documents in a specific index in the document store + """ + index = index or self.index + + # Build the properties to retrieve from Weaviate + properties = self._get_current_properties(index) + properties.append("_additional {id, certainty, vector}") + + if filters: + filter_dict = self._build_filter_clause(filters=filters) + result = self.weaviate_client.query.get(class_name=index, properties=properties)\ + .with_where(filter_dict)\ + .do() + else: + result = self.weaviate_client.query.get(class_name=index, properties=properties)\ + .do() + + all_docs = {} + if result and "data" in result and "Get" in result.get("data"): + if result.get("data").get("Get").get(index): + all_docs = result.get("data").get("Get").get(index) + + yield from all_docs + + def get_all_documents_generator( + self, + index: Optional[str] = None, + filters: Optional[Dict[str, List[str]]] = None, + return_embedding: Optional[bool] = None, + batch_size: int = 10_000, + ) -> Generator[Document, None, None]: + """ + Get documents from the document store. Under-the-hood, documents are fetched in batches from the + document store and yielded as individual documents. 
This method can be used to iteratively process + a large number of documents without having to load all documents in memory. + + :param index: Name of the index to get the documents from. If None, the + DocumentStore's default index (self.index) will be used. + :param filters: Optional filters to narrow down the documents to return. + Example: {"name": ["some", "more"], "category": ["only_one"]} + :param return_embedding: Whether to return the document embeddings. + :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + """ + + if index is None: + index = self.index + + if return_embedding is None: + return_embedding = self.return_embedding + + results = self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size) + for result in results: + document = self._convert_weaviate_result_to_document(result, return_embedding=return_embedding) + yield document + + def query( + self, + query: Optional[str] = None, + filters: Optional[Dict[str, List[str]]] = None, + top_k: int = 10, + custom_query: Optional[str] = None, + index: Optional[str] = None, + ) -> List[Document]: + """ + Scan through documents in DocumentStore and return a small number documents + that are most relevant to the query as defined by Weaviate semantic search. + + :param query: The query + :param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field + :param top_k: How many documents to return per query. 
+ :param custom_query: Custom query that will executed using query.raw method, for more details refer + https://www.semi.technology/developers/weaviate/current/graphql-references/filters.html + :param index: The name of the index in the DocumentStore from which to retrieve documents + """ + index = index or self.index + + # Build the properties to retrieve from Weaviate + properties = self._get_current_properties(index) + properties.append("_additional {id, certainty, vector}") + + if custom_query: + query_output = self.weaviate_client.query.raw(custom_query) + elif filters: + filter_dict = self._build_filter_clause(filters) + query_output = self.weaviate_client.query\ + .get(class_name=index, properties=properties)\ + .with_where(filter_dict)\ + .with_limit(top_k)\ + .do() + else: + raise NotImplementedError("Weaviate does not support inverted index text query. However, " + "it allows to search by filters example : {'text': 'some text'} or " + "use a custom GraphQL query in text format!") + + results = [] + if query_output and "data" in query_output and "Get" in query_output.get("data"): + if query_output.get("data").get("Get").get(index): + results = query_output.get("data").get("Get").get(index) + + documents = [] + for result in results: + doc = self._convert_weaviate_result_to_document(result, return_embedding=True) + documents.append(doc) + + return documents + + def query_by_embedding(self, + query_emb: np.ndarray, + filters: Optional[dict] = None, + top_k: int = 10, + index: Optional[str] = None, + return_embedding: Optional[bool] = None) -> List[Document]: + """ + Find the document that is most similar to the provided `query_emb` by using a vector similarity metric. + + :param query_emb: Embedding of the query (e.g. gathered from DPR) + :param filters: Optional filters to narrow down the search space. 
+ Example: {"name": ["some", "more"], "category": ["only_one"]} + :param top_k: How many documents to return + :param index: index name for storing the docs and metadata + :param return_embedding: To return document embedding + :return: + """ + if return_embedding is None: + return_embedding = self.return_embedding + index = index or self.index + + # Build the properties to retrieve from Weaviate + properties = self._get_current_properties(index) + properties.append("_additional {id, certainty, vector}") + + query_emb = query_emb.reshape(1, -1).astype(np.float32) + query_string = { + "vector" : query_emb + } + if filters: + filter_dict = self._build_filter_clause(filters) + query_output = self.weaviate_client.query\ + .get(class_name=index, properties=properties)\ + .with_where(filter_dict)\ + .with_near_vector(query_string)\ + .with_limit(top_k)\ + .do() + else: + query_output = self.weaviate_client.query\ + .get(class_name=index, properties=properties)\ + .with_near_vector(query_string)\ + .with_limit(top_k)\ + .do() + + results = [] + if query_output and "data" in query_output and "Get" in query_output.get("data"): + if query_output.get("data").get("Get").get(index): + results = query_output.get("data").get("Get").get(index) + + documents = [] + for result in results: + doc = self._convert_weaviate_result_to_document(result, return_embedding=return_embedding) + documents.append(doc) + + return documents + + def update_embeddings( + self, + retriever, + index: Optional[str] = None, + filters: Optional[Dict[str, List[str]]] = None, + update_existing_embeddings: bool = True, + batch_size: int = 10_000 + ): + """ + Updates the embeddings in the the document store using the encoding model specified in the retriever. + This can be useful if want to change the embeddings for your documents (e.g. after changing the retriever config). + + :param retriever: Retriever to use to update the embeddings. 
+ :param index: Index name to update + :param update_existing_embeddings: Weaviate mandates an embedding while creating the document itself. + This option must be always true for weaviate and it will update the embeddings for all the documents. + :param filters: Optional filters to narrow down the documents for which embeddings are to be updated. + Example: {"name": ["some", "more"], "category": ["only_one"]} + :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + :return: None + """ + if index is None: + index = self.index + + if not self.embedding_field: + raise RuntimeError("Specify the arg `embedding_field` when initializing WeaviateDocumentStore()") + + if update_existing_embeddings: + logger.info(f"Updating embeddings for all {self.get_document_count(index=index)} docs ...") + else: + raise RuntimeError("All the documents in Weaviate store have an embedding by default. Only update is allowed!") + + result = self._get_all_documents_in_index( + index=index, + filters=filters, + batch_size=batch_size, + ) + + for result_batch in get_batches_from_generator(result, batch_size): + document_batch = [self._convert_weaviate_result_to_document(hit, return_embedding=False) for hit in result_batch] + embeddings = retriever.embed_passages(document_batch) # type: ignore + assert len(document_batch) == len(embeddings) + + if embeddings[0].shape[0] != self.embedding_dim: + raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})" + f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})." + "Specify the arg `embedding_dim` when initializing WeaviateDocumentStore()") + for doc, emb in zip(document_batch, embeddings): + # This doc processing will not required once weaviate's update + # method works. To be improved. 
+ _doc = { + **doc.to_dict(field_map=self._create_document_field_map()) + } + _ = _doc.pop("score", None) + _ = _doc.pop("probability", None) + + if "meta" in _doc.keys(): + for k, v in _doc["meta"].items(): + _doc[k] = v + _doc.pop("meta") + + doc_id = str(_doc.pop("id")) + _ = _doc.pop(self.embedding_field) + keys_to_remove = [k for k,v in _doc.items() if v is None] + for key in keys_to_remove: + _doc.pop(key) + + # TODO: Weaviate's update throws an error while passing a vector now, have to improve this later + self.weaviate_client.data_object.replace(_doc, class_name=index, uuid=doc_id, vector=emb) + + def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None): + """ + Delete documents in an index. All documents are deleted if no filters are passed. + + :param index: Index name to delete the document from. + :param filters: Optional filters to narrow down the documents to be deleted. + :return: None + """ + index = index or self.index + if filters: + docs_to_delete = self.get_all_documents(index, filters=filters) + for doc in docs_to_delete: + self.weaviate_client.data_object.delete(doc.id) + else: + self.weaviate_client.schema.delete_class(index) + self._create_schema_and_index_if_not_exist(index) diff --git a/requirements.txt b/requirements.txt index 7303dae0a..c83ba516e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,5 @@ pymilvus #selenium #webdriver-manager SPARQLWrapper -mmh3 \ No newline at end of file +mmh3 +weaviate-client \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py index 42aac7e06..ceeb4a718 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -9,6 +9,9 @@ from elasticsearch import Elasticsearch from haystack.knowledge_graph.graphdb import GraphDBKnowledgeGraph from milvus import Milvus +import weaviate +from haystack.document_store.weaviate import WeaviateDocumentStore + from haystack.document_store.milvus import MilvusDocumentStore from 
haystack.generator.transformers import RAGenerator, RAGeneratorType
@@ -61,6 +64,8 @@ def pytest_collection_modifyitems(items):
             item.add_marker(pytest.mark.pipeline)
         elif "slow" in item.nodeid:
             item.add_marker(pytest.mark.slow)
+        elif "weaviate" in item.nodeid:
+            item.add_marker(pytest.mark.weaviate)


 @pytest.fixture(scope="session")
@@ -98,6 +103,27 @@ def milvus_fixture():
                                  'milvusdb/milvus:0.10.5-cpu-d010621-4eda95'], shell=True)
         time.sleep(40)

+@pytest.fixture(scope="session")
+def weaviate_fixture():
+    # test if a Weaviate server is already running. If not, start Weaviate docker container locally.
+    # Make sure you have given > 6GB memory to docker engine
+    try:
+        weaviate_server = weaviate.Client(url='http://localhost:8080', timeout_config=(5, 15))
+        weaviate_server.is_ready()
+    except Exception:
+        print("Starting Weaviate servers ...")
+        status = subprocess.run(
+            ['docker rm haystack_test_weaviate'],
+            shell=True
+        )
+        status = subprocess.run(
+            ["docker run -d --name haystack_test_weaviate -p 8080:8080 --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' semitechnologies/weaviate:1.4.0"],
+            shell=True
+        )
+        if status.returncode:
+            raise Exception("Failed to launch Weaviate. Please check docker container logs.")
+        time.sleep(60)

 @pytest.fixture(scope="session")
 def graphdb_fixture():
@@ -310,7 +336,6 @@ def document_store_with_docs(request, test_docs_xs):
     yield document_store
     document_store.delete_all_documents()

-
 @pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
 def document_store(request, test_docs_xs):
     document_store = get_document_store(request.param)
@@ -352,6 +377,14 @@ def get_document_store(document_store_type, embedding_field="embedding"):
             if collection.startswith("haystack_test"):
                 document_store.milvus_server.drop_collection(collection)
         return document_store
+    elif document_store_type == "weaviate":
+        document_store = WeaviateDocumentStore(
+            weaviate_url="http://localhost:8080",
+            index="Haystacktest"
+        )
+        document_store.weaviate_client.schema.delete_all()
+        document_store._create_schema_and_index_if_not_exist()
+        return document_store
     else:
         raise Exception(f"No document store fixture for '{document_store_type}'")

diff --git a/test/pytest.ini b/test/pytest.ini
index 9da069d5b..fe36dbb18 100644
--- a/test/pytest.ini
+++ b/test/pytest.ini
@@ -8,3 +8,4 @@ markers =
     generator: marks generator tests (deselect with '-m "not generator"')
     pipeline: marks tests with pipeline
     summarizer: marks summarizer tests
+    weaviate: marks tests that require weaviate container
diff --git a/test/test_weaviate.py b/test/test_weaviate.py
new file mode 100644
index 000000000..56fc50f82
--- /dev/null
+++ b/test/test_weaviate.py
@@ -0,0 +1,330 @@
+import numpy as np
+import pytest
+from haystack import Document
+from conftest import get_document_store
+import uuid
+
+embedding_dim = 768
+
+def get_uuid():
+    return str(uuid.uuid4())
+
+DOCUMENTS = [
+    {"text": "text1", "id": get_uuid(), "key": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+    {"text": "text2", "id": get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+    {"text": "text3", "id": get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+    {"text": "text4", "id": get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+    {"text": "text5", "id": get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+]
+
+DOCUMENTS_XS = [
+        # current "dict" format for a document
+        {"text": "My name is Carla and I live in Berlin", "id": get_uuid(), "meta": {"metafield": "test1", "name": "filename1"}, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+        # meta_field at the top level for backward compatibility
+        {"text": "My name is Paul and I live in New York", "id": get_uuid(), "metafield": "test2", "name": "filename2", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+        # Document object for a doc
+        Document(text="My name is Christelle and I live in Paris", id=get_uuid(), meta={"metafield": "test3", "name": "filename3"}, embedding=np.random.rand(embedding_dim).astype(np.float32))
+    ]
+
+@pytest.fixture(params=["weaviate"])
+def document_store_with_docs(request):
+    document_store = get_document_store(request.param)
+    document_store.write_documents(DOCUMENTS_XS)
+    yield document_store
+    document_store.delete_all_documents()
+
+@pytest.fixture(params=["weaviate"])
+def document_store(request):
+    document_store = get_document_store(request.param)
+    yield document_store
+    document_store.delete_all_documents()
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
+def test_get_all_documents_without_filters(document_store_with_docs):
+    documents = document_store_with_docs.get_all_documents()
+    assert all(isinstance(d, Document) for d in documents)
+    assert len(documents) == 3
+    assert {d.meta["name"] for d in documents} == {"filename1", "filename2", "filename3"}
+    assert {d.meta["metafield"] for d in documents} == {"test1", "test2", "test3"}
+
+@pytest.mark.weaviate
+def test_get_all_documents_with_correct_filters(document_store_with_docs):
+    documents = document_store_with_docs.get_all_documents(filters={"metafield": ["test2"]})
+    assert len(documents) == 1
+    assert documents[0].meta["name"] == "filename2"
+
+    documents = document_store_with_docs.get_all_documents(filters={"metafield": ["test1", "test3"]})
+    assert len(documents) == 2
+    assert {d.meta["name"] for d in documents} == {"filename1", "filename3"}
+    assert {d.meta["metafield"] for d in documents} == {"test1", "test3"}
+
+@pytest.mark.weaviate
+def test_get_all_documents_with_incorrect_filter_name(document_store_with_docs):
+    documents = document_store_with_docs.get_all_documents(filters={"incorrectmetafield": ["test2"]})
+    assert len(documents) == 0
+
+@pytest.mark.weaviate
+def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs):
+    documents = document_store_with_docs.get_all_documents(filters={"metafield": ["incorrect_value"]})
+    assert len(documents) == 0
+
+@pytest.mark.weaviate
+def test_get_documents_by_id(document_store_with_docs):
+    documents = document_store_with_docs.get_all_documents()
+    doc = document_store_with_docs.get_document_by_id(documents[0].id)
+    assert doc.id == documents[0].id
+    assert doc.text == documents[0].text
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_get_document_count(document_store):
+    document_store.write_documents(DOCUMENTS)
+    assert document_store.get_document_count() == 5
+    assert document_store.get_document_count(filters={"key": ["a"]}) == 1
+    assert document_store.get_document_count(filters={"key": ["b"]}) == 4
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+@pytest.mark.parametrize("batch_size", [2])
+def test_weaviate_write_docs(document_store, batch_size):
+    # Write in small batches
+    for i in range(0, len(DOCUMENTS), batch_size):
+        document_store.write_documents(DOCUMENTS[i: i + batch_size])
+
+    documents_indexed = document_store.get_all_documents()
+    assert len(documents_indexed) == len(DOCUMENTS)
+
+    documents_indexed = document_store.get_all_documents(batch_size=batch_size)
+    assert len(documents_indexed) == len(DOCUMENTS)
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_get_all_document_filter_duplicate_value(document_store):
+    documents = [
+        Document(
+            text="Doc1",
+            meta={"fone": "f0"},
+            id = get_uuid(),
+            embedding= np.random.rand(embedding_dim).astype(np.float32)
+        ),
+        Document(
+            text="Doc1",
+            meta={"fone": "f1", "metaid": "0"},
+            id = get_uuid(),
+            embedding = np.random.rand(embedding_dim).astype(np.float32)
+        ),
+        Document(
+            text="Doc2",
+            meta={"fthree": "f0"},
+            id = get_uuid(),
+            embedding=np.random.rand(embedding_dim).astype(np.float32)
+        )
+    ]
+    document_store.write_documents(documents)
+    documents = document_store.get_all_documents(filters={"fone": ["f1"]})
+    assert documents[0].text == "Doc1"
+    assert len(documents) == 1
+    assert {d.meta["metaid"] for d in documents} == {"0"}
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_get_all_documents_generator(document_store):
+    document_store.write_documents(DOCUMENTS)
+    assert len(list(document_store.get_all_documents_generator(batch_size=2))) == 5
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_write_with_duplicate_doc_ids(document_store):
+    id = get_uuid()
+    documents = [
+        Document(
+            text="Doc1",
+            id=id,
+            embedding=np.random.rand(embedding_dim).astype(np.float32)
+        ),
+        Document(
+            text="Doc2",
+            id=id,
+            embedding=np.random.rand(embedding_dim).astype(np.float32)
+        )
+    ]
+    document_store.write_documents(documents, duplicate_documents="skip")
+    with pytest.raises(Exception):
+        document_store.write_documents(documents, duplicate_documents="fail")
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+@pytest.mark.parametrize("update_existing_documents", [True, False])
+def test_update_existing_documents(document_store, update_existing_documents):
+    id = get_uuid()
+    original_docs = [
+        {"text": "text1_orig", "id": id, "metafieldforcount": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+    ]
+
+    updated_docs = [
+        {"text": "text1_new", "id": id, "metafieldforcount": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+    ]
+
+    document_store.update_existing_documents = update_existing_documents
+    document_store.write_documents(original_docs)
+    assert document_store.get_document_count() == 1
+
+    if update_existing_documents:
+        document_store.write_documents(updated_docs, duplicate_documents="overwrite")
+    else:
+        with pytest.raises(Exception):
+            document_store.write_documents(updated_docs, duplicate_documents="fail")
+
+    stored_docs = document_store.get_all_documents()
+    assert len(stored_docs) == 1
+    if update_existing_documents:
+        assert stored_docs[0].text == updated_docs[0]["text"]
+    else:
+        assert stored_docs[0].text == original_docs[0]["text"]
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_write_document_meta(document_store):
+    uid1 = get_uuid()
+    uid2 = get_uuid()
+    uid3 = get_uuid()
+    uid4 = get_uuid()
+    documents = [
+        {"text": "dict_without_meta", "id": uid1, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+        {"text": "dict_with_meta", "metafield": "test2", "name": "filename2", "id": uid2, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+        Document(text="document_object_without_meta", id=uid3, embedding=np.random.rand(embedding_dim).astype(np.float32)),
+        Document(text="document_object_with_meta", meta={"metafield": "test4", "name": "filename3"}, id=uid4, embedding=np.random.rand(embedding_dim).astype(np.float32)),
+    ]
+    document_store.write_documents(documents)
+    documents_in_store = document_store.get_all_documents()
+    assert len(documents_in_store) == 4
+
+    assert not document_store.get_document_by_id(uid1).meta
+    assert document_store.get_document_by_id(uid2).meta["metafield"] == "test2"
+    assert not document_store.get_document_by_id(uid3).meta
+    assert document_store.get_document_by_id(uid4).meta["metafield"] == "test4"
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_write_document_index(document_store):
+    documents = [
+        {"text": "text1", "id": get_uuid(), "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+        {"text": "text2", "id": get_uuid(), "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+    ]
+
+    document_store.write_documents([documents[0]], index="Haystackone")
+    assert len(document_store.get_all_documents(index="Haystackone")) == 1
+
+    document_store.write_documents([documents[1]], index="Haystacktwo")
+    assert len(document_store.get_all_documents(index="Haystacktwo")) == 1
+
+    assert len(document_store.get_all_documents(index="Haystackone")) == 1
+    assert len(document_store.get_all_documents()) == 0
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("retriever", ["dpr", "embedding"], indirect=True)
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_update_embeddings(document_store, retriever):
+    documents = []
+    for i in range(6):
+        documents.append({"text": f"text_{i}", "id": str(uuid.uuid4()), "metafield": f"value_{i}", "embedding": np.random.rand(embedding_dim).astype(np.float32)})
+    documents.append({"text": "text_0", "id": str(uuid.uuid4()), "metafield": "value_0", "embedding": np.random.rand(embedding_dim).astype(np.float32)})
+
+    document_store.write_documents(documents, index="HaystackTestOne")
+    document_store.update_embeddings(retriever, index="HaystackTestOne", batch_size=3)
+    documents = document_store.get_all_documents(index="HaystackTestOne", return_embedding=True)
+    assert len(documents) == 7
+    for doc in documents:
+        assert type(doc.embedding) is np.ndarray
+
+    documents = document_store.get_all_documents(
+        index="HaystackTestOne",
+        filters={"metafield": ["value_0"]},
+        return_embedding=True,
+    )
+    assert len(documents) == 2
+    for doc in documents:
+        assert doc.meta["metafield"] == "value_0"
+    np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)
+
+    documents = document_store.get_all_documents(
+        index="HaystackTestOne",
+        filters={"metafield": ["value_1", "value_5"]},
+        return_embedding=True,
+    )
+    np.testing.assert_raises(
+        AssertionError,
+        np.testing.assert_array_equal,
+        documents[0].embedding,
+        documents[1].embedding
+    )
+
+    doc = {"text": "text_7", "id": str(uuid.uuid4()), "metafield": "value_7",
+           "embedding": retriever.embed_queries(texts=["a random string"])[0]}
+    document_store.write_documents([doc], index="HaystackTestOne")
+
+    doc_before_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
+    embedding_before_update = doc_before_update.embedding
+
+    document_store.update_embeddings(
+        retriever, index="HaystackTestOne", batch_size=3, filters={"metafield": ["value_0", "value_1"]}
+    )
+    doc_after_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
+    embedding_after_update = doc_after_update.embedding
+    np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
+
+    # test update all embeddings
+    document_store.update_embeddings(retriever, index="HaystackTestOne", batch_size=3, update_existing_embeddings=True)
+    assert document_store.get_document_count(index="HaystackTestOne") == 8
+    doc_after_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
+    embedding_after_update = doc_after_update.embedding
+    np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, embedding_before_update, embedding_after_update)
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
+def test_query_by_embedding(document_store_with_docs):
+    docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32))
+    assert len(docs) == 3
+
+    docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32),
+                                                       top_k=1)
+    assert len(docs) == 1
+
+    docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32),
+                                                       filters = {"name": ['filename2']})
+    assert len(docs) == 1
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
+def test_query(document_store_with_docs):
+    query_text = 'My name is Carla and I live in Berlin'
+    with pytest.raises(Exception):
+        docs = document_store_with_docs.query(query_text)
+
+    docs = document_store_with_docs.query(filters = {"name": ['filename2']})
+    assert len(docs) == 1
+
+    docs = document_store_with_docs.query(filters={"text":[query_text.lower()]})
+    assert len(docs) == 1
+
+    docs = document_store_with_docs.query(filters={"text":['live']})
+    assert len(docs) == 3
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
+def test_delete_all_documents(document_store_with_docs):
+    assert len(document_store_with_docs.get_all_documents()) == 3
+
+    document_store_with_docs.delete_all_documents()
+    documents = document_store_with_docs.get_all_documents()
+    assert len(documents) == 0
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
+def test_delete_documents_with_filters(document_store_with_docs):
+    document_store_with_docs.delete_all_documents(filters={"metafield": ["test1", "test2"]})
+    documents = document_store_with_docs.get_all_documents()
+    assert len(documents) == 1
+    assert documents[0].meta["metafield"] == "test3"